{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w-_k4j9wm5GD"
      },
      "source": [
        "# Efficiently serving Large Language Models in 2bit with `aqlm` and `transformers` compiled into a CUDA graph\n",
        "\n",
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/Vahe1994/AQLM/blob/main/notebooks/aqlm_cuda_graph.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "Welcome to this notebook that goes through the recent `aqlm` integration that introduces efficient GPU utilization when serving LLMs quantized to 2bit.\n",
        "\n",
        "In this notebook, we will learn how to load a large model in 2bit (`Llama-2-7b`) and comile a CUDA graph of it, to circumvent Python overhead whem serving the model.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6egoxPVyckBF"
      },
      "source": [
        "**Install the `aqlm` library**\n",
        "- The only extra dependency to run AQLM models.\n",
        "- Add `[gpu]` to install the required CUDA specific dependencies.\n",
        "- To use nice features like `device_map` you'll need to install accelerate. To properly support AQLM, you'd have to install the latest version straight from their GitHub (to catch [PR#2376](https://github.com/huggingface/accelerate/pull/2376))."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A584OAwRWGks"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install aqlm[gpu]>=1.1.0\n",
        "!pip install accelerate>=0.27.0\n",
        "!pip install transformers>=4.41.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hTfcs4lrc1x4"
      },
      "source": [
        "**Load the model as usual**\n",
        "\n",
        "The tokenizer is just a normal `Llama 2` tokenizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lecaItWkVpIC"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "quantized_model = AutoModelForCausalLM.from_pretrained(\n",
        "    \"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf\",\n",
        "    torch_dtype=\"auto\", device_map=\"auto\", low_cpu_mem_usage=True,\n",
        ")\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "39QpRiPbcBYa"
      },
      "source": [
        "Do a few forward passes to load CUDA and automatically compile the kernels. It's done separately here for it not to affect the generation speed benchmark below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ii-mWRdQZCOF"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "output = quantized_model.generate(tokenizer(\"\", return_tensors=\"pt\")[\"input_ids\"].cuda(), max_new_tokens=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zOQfeb_ScIyb"
      },
      "source": [
        "**Measure generation speed**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q2CZ9QrA1S0P"
      },
      "outputs": [],
      "source": [
        "import time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hyl4uCxTdmKi"
      },
      "outputs": [],
      "source": [
        "start = time.perf_counter()\n",
        "output = quantized_model.generate(tokenizer(\"I'm AQLM, \", return_tensors=\"pt\")[\"input_ids\"].cuda(), min_new_tokens=128, max_new_tokens=128)\n",
        "end = time.perf_counter()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-yaWulHS1eqa",
        "outputId": "e940864a-0639-4113-9071-4659b32939fe"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Generating at 8.5 tok/s\n"
          ]
        }
      ],
      "source": [
        "print(f\"Generating at {128 / (end - start):.1f} tok/s\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nvShqlguccep"
      },
      "source": [
        "**Check that the output is what one would expect from Llama**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SsOmDVBvXobJ",
        "outputId": "b225a155-28bb-462e-dcae-fb1527bd78a6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<s> I'm AQLM, 20, and I'm from the UK. I'm a student at the University of Nottingham, studying English and Creative Writing. I'm a huge fan of the Harry Potter series, and I'm also a huge fan of the Marvel Cinematic Universe. I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan of the Star Wars franchise. I'm also a huge fan of the Marvel Cinematic Universe, and I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan\n"
          ]
        }
      ],
      "source": [
        "print(tokenizer.decode(output[0]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p4ON1sMP2c2P"
      },
      "source": [
        "### Compile a CUDA graph\n",
        "\n",
        "Note that `transformers` generation itself is not the fastest implementation and it's heavily influenced by CPU capabilities of _Google Colab_. We'll deal with it by using static caches and compiling the model's forward pass into a homogeneous CUDA graph, effectively removing python's overhead."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y70ZZqSi27uM"
      },
      "source": [
        "**We'll have to implement the logic around forward passes on our own since CUDA graphs are not yet integrated into transformers**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SoNpCyT72ffp"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):\n",
        "    logits = model(\n",
        "        cur_token,\n",
        "        position_ids=input_pos,\n",
        "        cache_position=cache_position,\n",
        "        past_key_values=past_key_values,\n",
        "        return_dict=False,\n",
        "        use_cache=True\n",
        "    )[0]\n",
        "    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]\n",
        "    return new_token\n",
        "\n",
        "MAX_NEW_TOKENS = 128"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Aly3cUrw3rvv"
      },
      "source": [
        "**Setup static KV cache for generation**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a9nlMa9S3S8h"
      },
      "outputs": [],
      "source": [
        "from transformers import StaticCache\n",
        "\n",
        "input_ids = tokenizer(\"I'm AQLM, \", return_tensors=\"pt\").to(\"cuda\")[\"input_ids\"]\n",
        "seq_length = input_ids.shape[1]\n",
        "\n",
        "past_key_values = StaticCache(\n",
        "    quantized_model.config,\n",
        "    1,\n",
        "    seq_length + MAX_NEW_TOKENS * 2 + 1,\n",
        "    quantized_model.device,\n",
        "    quantized_model.dtype\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aPSedHaQ2gdN"
      },
      "source": [
        "**Allocate token ids to be generated and copy prefix ids**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LSbQq6As343w"
      },
      "outputs": [],
      "source": [
        "cache_position = torch.arange(seq_length, device=\"cuda\")\n",
        "generated_ids = torch.zeros(1, seq_length + MAX_NEW_TOKENS * 2, dtype=torch.int, device=\"cuda\")\n",
        "generated_ids[:, cache_position] = input_ids.to(\"cuda\").to(torch.int)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xQu-Ppwo4Geu"
      },
      "source": [
        "**Do a forward pass to fill the prefix cache and compile the kernels if necessary**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Spi5_rXb3_IP"
      },
      "outputs": [],
      "source": [
        "logits = quantized_model(\n",
        "    input_ids, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True\n",
        ")[0]\n",
        "next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)\n",
        "generated_ids[:, [seq_length]] = next_token"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j06LyMyo4SE-"
      },
      "source": [
        "**Compile the CUDA graph with `torch.compile` and appply the forward pass repeatedly to generate text**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mRb_1-N-4KsA"
      },
      "outputs": [],
      "source": [
        "with torch.no_grad():\n",
        "    # Compile the CUDA graph\n",
        "    decode_one_tokens = torch.compile(decode_one_tokens, mode=\"reduce-overhead\", fullgraph=True)\n",
        "\n",
        "    # Generate tokens one by one\n",
        "    cache_position = torch.tensor([seq_length + 1], device=\"cuda\")\n",
        "    for _ in range(1, MAX_NEW_TOKENS):\n",
        "        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):\n",
        "            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position, past_key_values)\n",
        "            generated_ids[:, cache_position] = next_token.int()\n",
        "        cache_position += 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RYX1klVj-u7-",
        "outputId": "ed96681b-da97-4886-ffb4-d3a050ddcc5e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<s> I'm AQLM, 20, and I'm from the UK. I'm a student at the University of Nottingham, studying English and Creative Writing. I'm a huge fan of the Harry Potter series, and I'm also a huge fan of the Marvel Cinematic Universe. I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan of the Star Wars franchise. I'm also a huge fan of the Marvel Cinematic Universe, and I'm also a huge fan of the DC Extended Universe, and I'm also a huge fan<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>\n"
          ]
        }
      ],
      "source": [
        "print(tokenizer.decode(generated_ids[0]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p60kWdwy5BLF"
      },
      "source": [
        "**Continue the generation mesuring the speed**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mXGckNom4jBy"
      },
      "outputs": [],
      "source": [
        "start = time.perf_counter()\n",
        "with torch.no_grad():\n",
        "    for _ in range(MAX_NEW_TOKENS):\n",
        "        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):\n",
        "            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position, past_key_values)\n",
        "            generated_ids[:, cache_position] = next_token.int()\n",
        "        cache_position += 1\n",
        "end = time.perf_counter()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bUNO9J4-5UOa",
        "outputId": "dda822ba-e7fb-483d-9f3a-215d62175b14"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Generating at 24.0 tok/s\n"
          ]
        }
      ],
      "source": [
        "print(f\"Generating at {128 / (end - start):.1f} tok/s\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AXWMpf897N0H"
      },
      "source": [
        "We achieved a **3x** speedup over normal generation using CUDA graphs, and the generated text is almost identical, as it should be."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
